
Conversation

@quic-rb10

The VectorToLLVM pass currently includes an option, force32BitVectorIndices, to force 32-bit vector indices, but it lacks a mechanism to generically override the index bit width. To address this, we are introducing a new indexBitWidth option for the VectorToLLVM pass, allowing users to specify the bit width of the index type.
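
For reference, a usage sketch (this mirrors the RUN line in the new test file added by this patch; input.mlir is a placeholder):

  mlir-opt input.mlir -convert-vector-to-llvm='index-bitwidth=32'

When the option is left at its default of 0, the index bit width continues to be derived from the data layout, i.e. the machine word size.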

@quic-rb10
Author

cc: @javedabsar1

@llvmbot llvmbot added the mlir label Feb 21, 2025
@llvmbot
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir

Author: None (quic-rb10)


Patch is 62.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128154.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+19-22)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+2)
  • (added) mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir (+674)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+15-15)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-1)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..20eb6392daf49 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1414,6 +1414,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "vector::VectorTransformsOptions",
            /*default=*/"vector::VectorTransformsOptions()",
            "Options to lower some operations like contractions and transposes.">,
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
   ];
 }
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..1f8a222282aac 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -49,10 +49,9 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
                        int64_t pos) {
   assert(rank > 0 && "0-D vector corner case should have been handled already");
   if (rank == 1) {
-    auto idxType = rewriter.getIndexType();
+    auto idxType = typeConverter.convertType(rewriter.getIndexType());
     auto constant = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter.convertType(idxType),
-        rewriter.getIntegerAttr(idxType, pos));
+        loc, idxType, rewriter.getIntegerAttr(idxType, pos));
     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                   constant);
   }
@@ -64,10 +63,9 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
                         const LLVMTypeConverter &typeConverter, Location loc,
                         Value val, Type llvmType, int64_t rank, int64_t pos) {
   if (rank <= 1) {
-    auto idxType = rewriter.getIndexType();
+    auto idxType = typeConverter.convertType(rewriter.getIndexType());
     auto constant = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter.convertType(idxType),
-        rewriter.getIntegerAttr(idxType, pos));
+        loc, idxType, rewriter.getIntegerAttr(idxType, pos));
     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                    constant);
   }
@@ -1064,10 +1062,9 @@ class VectorExtractElementOpConversion
 
     if (vectorType.getRank() == 0) {
       Location loc = extractEltOp.getLoc();
-      auto idxType = rewriter.getIndexType();
+      auto idxType = typeConverter->convertType(rewriter.getIndexType());
       auto zero = rewriter.create<LLVM::ConstantOp>(
-          loc, typeConverter->convertType(idxType),
-          rewriter.getIntegerAttr(idxType, 0));
+          loc, idxType, rewriter.getIntegerAttr(idxType, 0));
       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
           extractEltOp, llvmType, adaptor.getVector(), zero);
       return success();
@@ -1198,10 +1195,9 @@ class VectorInsertElementOpConversion
 
     if (vectorType.getRank() == 0) {
       Location loc = insertEltOp.getLoc();
-      auto idxType = rewriter.getIndexType();
+      auto idxType = typeConverter->convertType(rewriter.getIndexType());
       auto zero = rewriter.create<LLVM::ConstantOp>(
-          loc, typeConverter->convertType(idxType),
-          rewriter.getIntegerAttr(idxType, 0));
+          loc, idxType, rewriter.getIntegerAttr(idxType, 0));
       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
       return success();
@@ -1439,8 +1435,6 @@ class VectorTypeCastOpConversion
     if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
       return failure();
 
-    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
-
     // Create descriptor.
     auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
     // Set allocated ptr.
@@ -1451,21 +1445,24 @@ class VectorTypeCastOpConversion
     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
     desc.setAlignedPtr(rewriter, loc, ptr);
     // Fill offset 0.
-    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
-    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+
+    auto idxType = typeConverter->convertType(rewriter.getIndexType());
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        loc, idxType, rewriter.getIntegerAttr(idxType, 0));
     desc.setOffset(rewriter, loc, zero);
 
     // Fill size and stride descriptors in memref.
     for (const auto &indexedSize :
          llvm::enumerate(targetMemRefType.getShape())) {
       int64_t index = indexedSize.index();
-      auto sizeAttr =
-          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
-      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+
+      auto size = rewriter.create<LLVM::ConstantOp>(
+          loc, idxType, rewriter.getIntegerAttr(idxType, indexedSize.value()));
       desc.setSize(rewriter, loc, index, size);
-      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
-                                                (*targetStrides)[index]);
-      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+
+      auto stride = rewriter.create<LLVM::ConstantOp>(
+          loc, idxType,
+          rewriter.getIntegerAttr(idxType, (*targetStrides)[index]));
       desc.setStride(rewriter, loc, index, stride);
     }
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..c9b6a528c03b6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -86,6 +86,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 
   // Convert to the LLVM IR dialect.
   LowerToLLVMOptions options(&getContext());
+  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+    options.overrideIndexBitwidth(indexBitwidth);
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir b/mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir
new file mode 100644
index 0000000000000..0869cd28b29b2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir
@@ -0,0 +1,674 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm='index-bitwidth=32' -split-input-file | FileCheck %s
+
+// CHECK-LABEL:   func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK:           %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK:           %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_8:.*]] = "llvm.intr.vp.reduce.fadd"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+// CHECK:           return %[[VAL_8]] : f32
+// CHECK:         }
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_reduce_minf_f32_scalable(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                               %[[VAL_1:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK:           %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK:           %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_8:.*]] = "llvm.intr.vp.reduce.fmin"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+// CHECK:           return %[[VAL_8]] : f32
+// CHECK:         }
+func.func @masked_reduce_minf_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minnumf>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK:           %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK:           %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_8:.*]] = "llvm.intr.vp.reduce.add"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK:           return %[[VAL_8]] : i8
+// CHECK:         }
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_reduce_minui_i8_scalable(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                                               %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK:           %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK:           %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_8:.*]] = "llvm.intr.vp.reduce.umin"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK:           return %[[VAL_8]] : i8
+// CHECK:         }
+func.func @masked_reduce_minui_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_reduce_maxsi_i8_scalable(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                                               %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK:           %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK:           %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_8:.*]] = "llvm.intr.vp.reduce.smax"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK:           return %[[VAL_8]] : i8
+// CHECK:         }
+func.func @masked_reduce_maxsi_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @masked_reduce_xor_i8_scalable(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK:           %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK:           %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_8:.*]] = "llvm.intr.vp.reduce.xor"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK:           return %[[VAL_8]] : i8
+// CHECK:         }
+func.func @masked_reduce_xor_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @shuffle_1D(
+// CHECK-SAME:                          %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                          %[[VAL_1:.*]]: vector<3xf32>) -> vector<5xf32> {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.poison : vector<5xf32>
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_3]] : i32] : vector<3xf32>
+// CHECK:           %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_6:.*]] = llvm.insertelement %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_5]] : i32] : vector<5xf32>
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_7]] : i32] : vector<3xf32>
+// CHECK:           %[[VAL_9:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_10:.*]] = llvm.insertelement %[[VAL_8]], %[[VAL_6]]{{\[}}%[[VAL_9]] : i32] : vector<5xf32>
+// CHECK:           %[[VAL_11:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_12:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_11]] : i32] : vector<3xf32>
+// CHECK:           %[[VAL_13:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_14:.*]] = llvm.insertelement %[[VAL_12]], %[[VAL_10]]{{\[}}%[[VAL_13]] : i32] : vector<5xf32>
+// CHECK:           %[[VAL_15:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_16:.*]] = llvm.extractelement %[[VAL_0]]{{\[}}%[[VAL_15]] : i32] : vector<2xf32>
+// CHECK:           %[[VAL_17:.*]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK:           %[[VAL_18:.*]] = llvm.insertelement %[[VAL_16]], %[[VAL_14]]{{\[}}%[[VAL_17]] : i32] : vector<5xf32>
+// CHECK:           %[[VAL_19:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_20:.*]] = llvm.extractelement %[[VAL_0]]{{\[}}%[[VAL_19]] : i32] : vector<2xf32>
+// CHECK:           %[[VAL_21:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK:           %[[VAL_22:.*]] = llvm.insertelement %[[VAL_20]], %[[VAL_18]]{{\[}}%[[VAL_21]] : i32] : vector<5xf32>
+// CHECK:           return %[[VAL_22]] : vector<5xf32>
+// CHECK:         }
+func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
+  %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
+  return %1 : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @extractelement_from_vec_0d_f32(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: vector<f32>) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<f32> to vector<1xf32>
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_2]] : i32] : vector<1xf32>
+// CHECK:           return %[[VAL_3]] : f32
+// CHECK:         }
+func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 {
+  %1 = vector.extractelement %arg0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @insertelement_into_vec_0d_f32(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: f32,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: vector<f32>) -> vector<f32> {
+// CHECK:           %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_2]]{{\[}}%[[VAL_3]] : i32] : vector<1xf32>
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : vector<1xf32> to vector<f32>
+// CHECK:           return %[[VAL_5]] : vector<f32>
+// CHECK:         }
+func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {
+  %1 = vector.insertelement %arg0, %arg1[] : vector<f32>
+  return %1 : vector<f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @type_cast_f32(
+// CHECK-SAME:                             %[[VAL_0:.*]]: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
+// CHECK:           %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : memref<8x8x8xf32> to !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i32)>
+// CHECK:           %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][0] : !llvm.struct<(ptr, ptr, i32)>
+// CHECK:           %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK:           %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_4]][1] : !llvm.struct<(ptr, ptr, i32)>
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_6]][2] : !llvm.struct<(ptr, ptr, i32)>
+// CHECK:           %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !llvm.struct<(ptr, ptr, i32)> to memref<vector<8x8x8xf32>>
+// CHECK:           return %[[VAL_9]] : memref<vector<8x8x8xf32>>
+// CHECK:         }
+func.func @type_cast_f32(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
+  %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
+  return %0 : memref<vector<8x8x8xf32>>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @type_cast_non_zero_addrspace(
+// CHECK-SAME:                                            %[[VAL_0:.*]]: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
+// CHECK:           %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : memref<8x8x8xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<3>, ptr<3>, i32)>
+// CHECK:           %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][0] : !llvm.struct<(ptr<3>, ptr<3>, i32)>
+// CHECK:           %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK:           %[[VAL_6:.*]] = llvm.insertvalue %[[VAL...
[truncated]
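
To make the new option's effect concrete, here is a minimal before/after sketch distilled from the tests above (the function and value names are illustrative; the i64 variant is the pre-existing default lowering under a 64-bit data layout):

  func.func @extract_first(%v: vector<4xf32>) -> f32 {
    %0 = vector.extract %v[0] : f32 from vector<4xf32>
    return %0 : f32
  }

  // Default (index width derived from the data layout, i64 here):
  //   %c0 = llvm.mlir.constant(0 : i64) : i64
  //   %e  = llvm.extractelement %v[%c0 : i64] : vector<4xf32>
  //
  // With -convert-vector-to-llvm='index-bitwidth=32':
  //   %c0 = llvm.mlir.constant(0 : i32) : i32
  //   %e  = llvm.extractelement %v[%c0 : i32] : vector<4xf32>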

/*default=*/"vector::VectorTransformsOptions()",
"Options to lower some operations like contractions and transposes.">,
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
Contributor

Cool! How does this deriving of the index bitwidth from the data layout work? Is there a single bitwidth per module? It would be great if we could extend this to have a bitwidth per address space at some point :)

Author

@quic-rb10 commented Feb 27, 2025

@dcaballe Could you please elaborate on this? I have added some of my findings regarding the index bitwidth derivation from the data layout as a reply to another of your comments.
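
For context, a sketch of the mechanism in question, restating the ConvertVectorToLLVMPass.cpp hunk above with explanatory comments (all names here are existing upstream ones; nothing below is new API):

  // 0 is LowerToLLVMOptions::kDeriveIndexBitwidthFromDataLayout: the type
  // converter keeps the index width computed from the module's data layout
  // (the machine word size by default), so there is one width per module.
  LowerToLLVMOptions options(&getContext());
  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
    options.overrideIndexBitwidth(indexBitwidth);
  LLVMTypeConverter converter(&getContext(), options);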

@quic-rb10 quic-rb10 force-pushed the index_bit_width_vector_mlir branch 2 times, most recently from 102fcf0 to ae8d205 on March 3, 2025, 20:07
Change-Id: I1ad6f77183f1f1faf25e935131de4ef3a4334150
@quic-rb10 quic-rb10 force-pushed the index_bit_width_vector_mlir branch from ae8d205 to 45d715a on March 6, 2025, 16:30
@quic-rb10 quic-rb10 requested a review from dcaballe March 11, 2025 05:35
@javedabsar1 javedabsar1 self-requested a review March 11, 2025 09:32